{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using quaternions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quaternions are widely used in 3D computer graphics and robotics to represent the orientation of rigid bodies in 3D space. Unlike Euler angles, they do not suffer from gimbal lock.\n",
    "\n",
    "Kaolin provides a set of functions to convert between (batches of) quaternions and other 3D rotation representations such as 3x3 rotation matrices, angle-axis representations, and 4x4 transformation matrices that combine a rotation and a translation.\n",
    "\n",
    "Quaternions are represented in the form:\n",
    "```\n",
    "w + x*i + y*j + z*k\n",
    "```\n",
    "where `w` is the real component and `x`, `y`, and `z` are the imaginary components of the quaternion.\n",
    "Kaolin stores the components in scalar-last (`xyzw`) order, which is common in many graphics libraries.\n",
    "\n",
    "To learn more about quaternions and how they can be used to represent rotations, check out this [tutorial](https://eater.net/quaternions) from [Grant Sanderson (3blue1brown)](https://www.3blue1brown.com/) and [Ben Eater](https://eater.net/).\n",
    "\n",
    "Kaolin's functions provide a consistent interface across different 3D rotation representations, making it easy to switch between them while preserving batching. All operations are compiled with the PyTorch JIT for acceleration."
   ]
  },
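  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief reminder of the standard convention (general quaternion math, not taken from the Kaolin documentation), a rotation by an angle $\\theta$ about a unit axis $\\mathbf{n} = (n_x, n_y, n_z)$ corresponds to the following unit quaternion, written here in scalar-last order:\n",
    "\n",
    "$$\n",
    "(x, y, z, w) = \\left(n_x \\sin\\tfrac{\\theta}{2},\\; n_y \\sin\\tfrac{\\theta}{2},\\; n_z \\sin\\tfrac{\\theta}{2},\\; \\cos\\tfrac{\\theta}{2}\\right)\n",
    "$$\n",
    "\n",
    "Setting $\\theta = 0$ gives $(0, 0, 0, 1)$, the identity rotation. The quaternions $q$ and $-q$ describe the same rotation."
   ]
  },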
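  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up the environment\n",
    "import torch\n",
    "from kaolin.math import quat\n",
    "\n",
    "# use the GPU when available; the outputs shown below were produced on CUDA\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"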
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quaternion basics\n",
    "\n",
    "Kaolin supports creating batches of quaternions from different representations.\n",
    "Here we create a batch of identity quaternions and build a batch of quaternions from a torch tensor in scalar-last (`xyzw`) order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0., 1.],\n",
      "        [0., 0., 0., 1.]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# generate a batch of 2 identity quaternions\n",
    "identity = quat.quat_identity([2], device=device)\n",
    "print(identity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  2,   3,   4,   1],\n",
      "        [  0,  -4,  -2, -10]], device='cuda:0')\n",
      "real components:\n",
      "tensor([[  1],\n",
      "        [-10]], device='cuda:0')\n",
      "imaginary components:\n",
      "tensor([[ 2,  3,  4],\n",
      "        [ 0, -4, -2]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# manually create a batch of 2 quaternions\n",
    "quats = torch.tensor([[2, 3, 4, 1], [0, -4, -2, -10]], device=device)\n",
    "print(quats)\n",
    "\n",
    "# get the real (`w`) component of the two quaternions\n",
    "print(f\"real components:\\n{quat.quat_real(quats)}\")\n",
    "\n",
    "# get the imaginary (`xyz`) component of the two quaternions\n",
    "print(f\"imaginary components:\\n{quat.quat_imaginary(quats)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Kaolin supports common quaternion operations such as normalizing a quaternion so that it represents a valid 3D rotation, multiplying two quaternions, and rotating a 3D point by a quaternion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  2.,   3.,   4.,   1.],\n",
      "        [  0.,  -4.,  -2., -10.]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# multiply two batches of quaternions; multiplying by the identity returns the values of `quats` unchanged\n",
    "quats_mul = quat.quat_mul(identity, quats)\n",
    "print(quats_mul)"
   ]
  },
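  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the textbook Hamilton product can be written compactly by splitting each quaternion into its imaginary vector part $\\mathbf{v} = (x, y, z)$ and its real part $w$. This is a sketch of the standard definition, assuming `quat_mul` follows it (the output above is consistent with that assumption):\n",
    "\n",
    "$$\n",
    "q_1 q_2 = \\left(w_1 \\mathbf{v}_2 + w_2 \\mathbf{v}_1 + \\mathbf{v}_1 \\times \\mathbf{v}_2,\\; w_1 w_2 - \\mathbf{v}_1 \\cdot \\mathbf{v}_2\\right)\n",
    "$$\n",
    "\n",
    "With the identity $q_1 = (\\mathbf{0}, 1)$ this reduces to $(\\mathbf{v}_2, w_2)$, so multiplying by the identity leaves $q_2$ unchanged. In general the product is not commutative."
   ]
  },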
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.3651, 0.5477, 0.7303, 0.1826],\n",
      "        [-0.0000, 0.3651, 0.1826, 0.9129]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# normalize each quaternion to unit norm and flip its sign if needed so the real part is non-negative\n",
    "valid_3d = quat.quat_unit_positive(quats)\n",
    "print(valid_3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.8000, 2.0000, 2.6000],\n",
      "        [2.0000, 2.6000, 1.8000]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# rotate a point by the quaternion\n",
    "point = torch.tensor([[1, 2, 3], [1, 2, 3]], device=device)\n",
    "\n",
    "point_rotated = quat.quat_rotate(rotation=valid_3d, point=point)\n",
    "print(point_rotated)"
   ]
  },
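  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A unit quaternion $q = (\\mathbf{v}, w)$ rotates a point $\\mathbf{p}$ through the product $q\\,(\\mathbf{p}, 0)\\,q^{-1}$. Expanding that product gives an equivalent vector form. This is a standard identity; whether `quat_rotate` evaluates exactly this expression internally is an assumption:\n",
    "\n",
    "$$\n",
    "\\mathbf{p}' = \\mathbf{p} + 2w\\,(\\mathbf{v} \\times \\mathbf{p}) + 2\\,\\mathbf{v} \\times (\\mathbf{v} \\times \\mathbf{p})\n",
    "$$\n",
    "\n",
    "For the first quaternion, $\\mathbf{v} = (2, 3, 4)/\\sqrt{30}$, $w = 1/\\sqrt{30}$, and $\\mathbf{p} = (1, 2, 3)$. Then $\\mathbf{v} \\times \\mathbf{p} = (1, -2, 1)/\\sqrt{30}$ and $\\mathbf{v} \\times (\\mathbf{v} \\times \\mathbf{p}) = (11, 2, -7)/30$, so\n",
    "\n",
    "$$\n",
    "\\mathbf{p}' = (1, 2, 3) + \\tfrac{1}{30}(2, -4, 2) + \\tfrac{1}{30}(22, 4, -14) = (1.8,\\; 2.0,\\; 2.6)\n",
    "$$\n",
    "\n",
    "which matches the first row printed above."
   ]
  },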
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quaternion conversions\n",
    "\n",
    "Kaolin supports conversions between quaternions and several other 3D representations:\n",
    "\n",
    "- 3x3 rotation matrix\n",
    "- 4x4 rotation matrix (homogeneous form)\n",
    "- angle-axis (rotation angle in radians and unit rotation axis)\n",
    "- Euclidean transformation (4x4 matrix combining a rotation and a translation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.6667,  0.1333,  0.7333],\n",
      "         [ 0.6667, -0.3333,  0.6667],\n",
      "         [ 0.3333,  0.9333,  0.1333]],\n",
      "\n",
      "        [[ 0.6667, -0.3333,  0.6667],\n",
      "         [ 0.3333,  0.9333,  0.1333],\n",
      "         [-0.6667,  0.1333,  0.7333]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# convert to 3x3 rotation matrix representation\n",
    "rot33 = quat.rot33_from_quat(quats)\n",
    "print(rot33)"
   ]
  },
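  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The matrices above can be checked by hand against the standard rotation matrix of a unit quaternion $(x, y, z, w)$. The formula below is the usual textbook expression, not code taken from Kaolin:\n",
    "\n",
    "$$\n",
    "R(q) = \\begin{pmatrix}\n",
    "1 - 2(y^2 + z^2) & 2(xy - zw) & 2(xz + yw) \\\\\\\\\n",
    "2(xy + zw) & 1 - 2(x^2 + z^2) & 2(yz - xw) \\\\\\\\\n",
    "2(xz - yw) & 2(yz + xw) & 1 - 2(x^2 + y^2)\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "The input `quats` was not normalized, yet the printed matrices agree with this formula applied to the normalized quaternions. For the first one, $(2, 3, 4, 1)/\\sqrt{30}$, the top-left entry is $1 - 2(9 + 16)/30 = -2/3 \\approx -0.6667$ and the next entry in that row is $2(6 - 4)/30 = 2/15 \\approx 0.1333$."
   ]
  },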
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.7744],\n",
      "        [0.8411]], device='cuda:0') tensor([[3.7139e-01, 5.5709e-01, 7.4278e-01],\n",
      "        [2.9200e-07, 8.9443e-01, 4.4721e-01]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# convert to angle-axis representation\n",
    "angle, axis = quat.angle_axis_from_quat(quats)\n",
    "print(angle, axis)"
   ]
  },
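  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed values are consistent with the standard relation between a unit quaternion $(\\mathbf{v}, w)$ with $w \\ge 0$ and its angle-axis form. This is stated as a sanity check rather than as a description of Kaolin's internals:\n",
    "\n",
    "$$\n",
    "\\theta = 2 \\arccos(w), \\qquad \\mathbf{n} = \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}\n",
    "$$\n",
    "\n",
    "For the first quaternion, normalizing $(2, 3, 4, 1)$ gives $w = 1/\\sqrt{30} \\approx 0.1826$, so $\\theta \\approx 2.7744$ radians, and $\\mathbf{n} = (2, 3, 4)/\\sqrt{29} \\approx (0.3714, 0.5571, 0.7428)$."
   ]
  },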
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.7744],\n",
      "        [0.8411]], device='cuda:0') tensor([[3.7139e-01, 5.5709e-01, 7.4278e-01],\n",
      "        [2.9200e-07, 8.9443e-01, 4.4721e-01]], device='cuda:0')\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# convert the 3x3 rotation matrix to the matching angle-axis representation\n",
    "angle2, axis2 = quat.angle_axis_from_rot33(rot33)\n",
    "print(angle2, axis2)\n",
    "\n",
    "# verify conversions retained the correct values\n",
    "print(torch.allclose(angle, angle2))\n",
    "print(torch.allclose(axis, axis2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.6667,  0.1333,  0.7333,  0.0000],\n",
      "         [ 0.6667, -0.3333,  0.6667,  0.0000],\n",
      "         [ 0.3333,  0.9333,  0.1333,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  1.0000]],\n",
      "\n",
      "        [[ 0.6667, -0.3333,  0.6667,  0.0000],\n",
      "         [ 0.3333,  0.9333,  0.1333,  0.0000],\n",
      "         [-0.6667,  0.1333,  0.7333,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  1.0000]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# convert to 4x4 rotation matrix\n",
    "rot44 = quat.rot44_from_quat(quats)\n",
    "print(rot44)"
   ]
  },
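  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 4x4 form follows the usual homogeneous-coordinate layout for column vectors, with the 3x3 rotation $R$ in the upper-left block and a translation $\\mathbf{t}$ in the last column. The symbols $T$, $R$, and $\\mathbf{t}$ are introduced here only for illustration:\n",
    "\n",
    "$$\n",
    "T = \\begin{pmatrix} R & \\mathbf{t} \\\\\\\\ \\mathbf{0}^\\top & 1 \\end{pmatrix},\n",
    "\\qquad\n",
    "\\begin{pmatrix} \\mathbf{p}' \\\\\\\\ 1 \\end{pmatrix} = T \\begin{pmatrix} \\mathbf{p} \\\\\\\\ 1 \\end{pmatrix}\n",
    "\\;\\Longleftrightarrow\\;\n",
    "\\mathbf{p}' = R\\,\\mathbf{p} + \\mathbf{t}\n",
    "$$\n",
    "\n",
    "The 4x4 rotation matrix printed above is the special case $\\mathbf{t} = \\mathbf{0}$."
   ]
  },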
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.6667,  0.1333,  0.7333,  1.0000],\n",
      "         [ 0.6667, -0.3333,  0.6667,  2.0000],\n",
      "         [ 0.3333,  0.9333,  0.1333,  3.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  1.0000]],\n",
      "\n",
      "        [[ 0.6667, -0.3333,  0.6667,  1.0000],\n",
      "         [ 0.3333,  0.9333,  0.1333,  2.0000],\n",
      "         [-0.6667,  0.1333,  0.7333,  3.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  1.0000]]], device='cuda:0')\n",
      "rotation component:\n",
      "tensor([[[-0.6667,  0.1333,  0.7333],\n",
      "         [ 0.6667, -0.3333,  0.6667],\n",
      "         [ 0.3333,  0.9333,  0.1333]],\n",
      "\n",
      "        [[ 0.6667, -0.3333,  0.6667],\n",
      "         [ 0.3333,  0.9333,  0.1333],\n",
      "         [-0.6667,  0.1333,  0.7333]]], device='cuda:0')\n",
      "translation component:\n",
      "tensor([[1., 2., 3.],\n",
      "        [1., 2., 3.]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# compose a Euclidean transform matrix from the quaternion rotation and a translation\n",
    "euclidean = quat.euclidean_from_rotation_translation(r=quats, t=torch.tensor([[1, 2, 3]]))\n",
    "print(euclidean)\n",
    "\n",
    "print(f\"rotation component:\\n{quat.euclidean_rotation_matrix(euclidean)}\")\n",
    "print(f\"translation component:\\n{quat.euclidean_translation_vector(euclidean)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
